from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

def knn_iris():
    """
    用KNN算法会鸢尾花进行分类
    :return:
    """
    # 1.获取数据
    iris = load_iris()
    print(iris)
    # 2.划分数据集
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)
    print(x_train)
    print(x_test)
    print(y_train)
    print(y_test)
    # 3.特征工程：标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test) #测试集使用训练集的数据（平均值）等进行标准化
    # 4.KNN 算法预估器
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)
    # 5.模型评估
    # 方法一：直接比对预测值与真实值
    y_predict = estimator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("y_test:\n", y_test)
    print("预测值与真实值比对：\n",  y_predict == y_test)
    # 方法二：计算准确率
    score = estimator.score(x_test, y_test)
    print("准确率为：\n", score)

    return None

if __name__ == "__main__":
    # 代码11：用KNN算法会鸢尾花进行分类
    knn_iris()